{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hXIJ7Mvizsm-"
   },
   "source": [
    "## Train/Evaluate a DNN for AMC\n",
    "\n",
    "This notebook demonstrates how to create a neural network and a trainer in PyTorch to learn a signal classification task.  The reference dataset used is the RML2016.10A dataset for Automatic Modulation Classification\n",
    "\n",
    "### Assumptions\n",
    "- The dataset wrangling has already been completed (and is provided here)\n",
    "- The classifier evaluation code (and the plotting) has already been completed\n",
    "\n",
    "### Components Recreated in Tutorial\n",
    "- Deep Neural Network Model defined in PyTorch\n",
    "- Training Loop that trains for *n* epochs\n",
    "\n",
    "### See Also\n",
    "The code in this tutorial is a stripped down version of the code in ``rfml.nn.model.CNN`` and ``rfml.nn.train.StandardTrainingStrategy`` that simplifies discussion.  Further detail can be provided by directly browsing the source files for those classes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fWuzJnD5zsm_"
   },
   "source": [
    "## Install the library code and dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_CYvIkdwzsnA"
   },
   "outputs": [],
   "source": [
    "# Install the library code\n",
    "!pip install git+https://github.com/brysef/rfml.git@1.0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xZXJtKgLzsnE"
   },
   "outputs": [],
   "source": [
    "# Plotting Includes\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "# External Includes\n",
    "import numpy as np\n",
    "from pprint import pprint\n",
    "\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Internal Includes\n",
    "from rfml.data import Dataset, Encoder\n",
    "from rfml.data.converters import load_RML201610A_dataset\n",
    "\n",
    "from rfml.nbutils import plot_acc_vs_snr, plot_confusion, plot_convergence, plot_IQ\n",
    "\n",
    "from rfml.nn.eval import compute_accuracy, compute_accuracy_on_cross_sections, compute_confusion\n",
    "from rfml.nn.model import Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yn08CS2vzsnG"
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vv27phvSzsnH"
   },
   "outputs": [],
   "source": [
    "gpu = True       # Set to True to use a GPU for training\n",
    "fig_dir = None   # Set to a file path if you'd like to save the plots generated\n",
    "data_path = None # Set to a file path if you've downloaded RML2016.10A locally"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AgsYbed4zsnJ"
   },
   "source": [
    "### Loading a Dataset\n",
    "\n",
    "The dataset used is downloaded from DeepSig Inc. and provided under a Creative Commons lic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 221
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8693,
     "status": "ok",
     "timestamp": 1563489014772,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "pYEitUXfzsnJ",
    "outputId": "aa2f18a8-f9be-4d77-fd39-562ae6181bd0"
   },
   "outputs": [],
   "source": [
    "dataset = load_RML201610A_dataset(path=data_path)\n",
    "print(len(dataset))\n",
    "pprint(dataset.get_examples_per_class())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 765
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9907,
     "status": "ok",
     "timestamp": 1563489015997,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "ZNvKWievzsnO",
    "outputId": "63dd8f28-b69e-4df9-d3fb-966a0aad4206"
   },
   "outputs": [],
   "source": [
    "train, test = dataset.split(frac=0.3, on=[\"Modulation\", \"SNR\"])\n",
    "train, val = train.split(frac=0.05, on=[\"Modulation\", \"SNR\"])\n",
    "\n",
    "print(\"Training Examples\")\n",
    "print(\"=================\")\n",
    "pprint(train.get_examples_per_class())\n",
    "print(\"=================\")\n",
    "print()\n",
    "print(\"Validation Examples\")\n",
    "print(\"=================\")\n",
    "pprint(val.get_examples_per_class())\n",
    "print(\"=================\")\n",
    "print()\n",
    "print(\"Testing Examples\")\n",
    "print(\"=================\")\n",
    "pprint(test.get_examples_per_class())\n",
    "print(\"=================\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 221
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9899,
     "status": "ok",
     "timestamp": 1563489015999,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "xPb0PiN9zsnR",
    "outputId": "479bb7fb-0b40-4216-a03f-1741b2a30400"
   },
   "outputs": [],
   "source": [
    "le = Encoder([\"WBFM\",\n",
    "              \"AM-DSB\",\n",
    "              \"AM-SSB\",\n",
    "              \"CPFSK\",\n",
    "              \"GFSK\",\n",
    "              \"BPSK\",\n",
    "              \"QPSK\",\n",
    "              \"8PSK\",\n",
    "              \"PAM4\",\n",
    "              \"QAM16\",\n",
    "              \"QAM64\"],\n",
    "             label_name=\"Modulation\")\n",
    "print(le)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 350
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10198,
     "status": "ok",
     "timestamp": 1563489016311,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "Xd0zIGl7zsnT",
    "outputId": "aedeb7a5-d2c8-4d8b-9196-d5ee5a32bc6e"
   },
   "outputs": [],
   "source": [
    "# Plot a sample of the data\n",
    "# You can choose a different sample by changing\n",
    "idx = 10\n",
    "snr = 18.0\n",
    "modulation = \"8PSK\"\n",
    "\n",
    "mask = (dataset.df[\"SNR\"] == snr) & (dataset.df[\"Modulation\"] == modulation)\n",
    "sample = dataset.as_numpy(mask=mask, le=le)[0][idx,0,:]\n",
    "t = np.arange(sample.shape[1])\n",
    "\n",
    "title = \"{modulation} Sample at {snr:.0f} dB SNR\".format(modulation=modulation, snr=snr)\n",
    "fig = plot_IQ(iq=sample, title=title)\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/{modulation}_{snr:.0f}dB_sample.pdf\".format(fig_dir=fig_dir,\n",
    "                                                                       modulation=modulation,\n",
    "                                                                       snr=snr)\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WPWtP6dJzsnV"
   },
   "source": [
    "### Creating a Neural Network Model\n",
    "We are going to recreate a Convolutional Neural Network (CNN) based on the \"VT_CNN2\" Architecture.  This network is based off of a network for modulation classification first\n",
    "introduced in [O'Shea et al.] and later updated by [West/Oshea] and [Hauser et al.]\n",
    "to have larger filter sizes.\n",
    "\n",
    "#### Citations\n",
    "##### O'Shea et al.\n",
    "T. J. O'Shea, J. Corgan, and T. C. Clancy, “Convolutional radio modulation recognition networks,” in International Conference on Engineering Applications of Neural Networks, pp. 213–226, Springer,2016.\n",
    "\n",
    "##### West/O'Shea\n",
    "N. E. West and T. O’Shea, “Deep architectures for modulation recognition,” in IEEE International Symposium on Dynamic Spectrum Access Networks (DySPAN), pp. 1–6, IEEE, 2017.\n",
    "\n",
    "##### Hauser et al.\n",
    "S. C. Hauser, W. C. Headley, and A. J.  Michaels, “Signal detection effects on deep neural networks utilizing raw iq for modulation classification,” in Military Communications Conference, pp. 121–127, IEEE, 2017."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QHzcQJGyzsnY"
   },
   "outputs": [],
   "source": [
    "class MyCNN(Model):\n",
    "    def __init__(self, input_samples: int, n_classes: int):\n",
    "        super().__init__(input_samples=input_samples, n_classes=n_classes)\n",
    "        # Batch x 1-channel x IQ x input_samples\n",
    "        # Modifying the first convolutional layer to not use a bias term is a\n",
    "        # modification made by Bryse Flowers due to the observation of vanishing\n",
    "        # gradients during training when ported to PyTorch (other authors used\n",
    "        # Keras).\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            in_channels=1,\n",
    "            out_channels=256,\n",
    "            kernel_size=(1, 7),\n",
    "            padding=(0, 3),\n",
    "            bias=False,\n",
    "        )\n",
    "        self.a1 = nn.ReLU()\n",
    "        self.n1 = nn.BatchNorm2d(256)\n",
    "\n",
    "        self.conv2 = nn.Conv2d(\n",
    "            in_channels=256,\n",
    "            out_channels=80,\n",
    "            kernel_size=(2, 7),\n",
    "            padding=(0, 3),\n",
    "            bias=True,\n",
    "        )\n",
    "        self.a2 = nn.ReLU()\n",
    "        self.n2 = nn.BatchNorm2d(80)\n",
    "\n",
    "        # Batch x Features\n",
    "        self.dense1 = nn.Linear(80 * 1 * input_samples, 256)\n",
    "        self.a3 = nn.ReLU()\n",
    "        self.n3 = nn.BatchNorm1d(256)\n",
    "\n",
    "        self.dense2 = nn.Linear(256, n_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.a1(x)\n",
    "        x = self.n1(x)\n",
    "\n",
    "        x = self.conv2(x)\n",
    "        x = self.a2(x)\n",
    "        x = self.n2(x)\n",
    "\n",
    "        # Flatten the input layer down to 1-d by using Tensor operations\n",
    "        x = x.contiguous()\n",
    "        x = x.view(x.size()[0], -1)\n",
    "\n",
    "        x = self.dense1(x)\n",
    "        x = self.a3(x)\n",
    "        x = self.n3(x)\n",
    "\n",
    "        x = self.dense2(x)\n",
    "\n",
    "        return x"
   ]
  },
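  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a shape check rather than part of the original tutorial. It pushes a dummy batch through two standalone convolutions configured like `conv1` and `conv2` above, which may help show why `dense1` expects `80 * 1 * input_samples` input features. The `probe_*` names and the batch size of 4 are assumptions made purely for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Standalone copies of the two convolutions (hypothetical probe_* names)\n",
    "probe_conv1 = nn.Conv2d(1, 256, kernel_size=(1, 7), padding=(0, 3), bias=False)\n",
    "probe_conv2 = nn.Conv2d(256, 80, kernel_size=(2, 7), padding=(0, 3))\n",
    "\n",
    "# Dummy input: Batch x 1-channel x IQ x input_samples\n",
    "probe_x = torch.zeros(4, 1, 2, 128)\n",
    "with torch.no_grad():\n",
    "    # A (1, 7) kernel with (0, 3) padding keeps IQ=2 and all 128 samples\n",
    "    probe_h1 = probe_conv1(probe_x)\n",
    "    # A (2, 7) kernel with (0, 3) padding collapses the IQ dimension from 2 to 1\n",
    "    probe_h2 = probe_conv2(probe_h1)\n",
    "\n",
    "# Same flattening step as in MyCNN.forward\n",
    "probe_flat = probe_h2.view(probe_h2.size(0), -1)\n",
    "\n",
    "print(probe_h1.shape)    # torch.Size([4, 256, 2, 128])\n",
    "print(probe_h2.shape)    # torch.Size([4, 80, 1, 128])\n",
    "print(probe_flat.shape)  # torch.Size([4, 10240]), i.e. 80 * 1 * 128"
   ]
  },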
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10169,
     "status": "ok",
     "timestamp": 1563489016317,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "5XcGfWWEzsnZ",
    "outputId": "77e3826e-b564-4348-e31a-54350528f82d"
   },
   "outputs": [],
   "source": [
    "model = MyCNN(input_samples=128, n_classes=11)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cdWw1puezsnb"
   },
   "source": [
    "### Implementing a Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "I39flOAMzsnb"
   },
   "outputs": [],
   "source": [
    "class MyTrainingStrategy(object):\n",
    "\n",
    "    def __init__(self, lr: float = 10e-4, n_epochs: int = 3, gpu: bool = True):\n",
    "        self.lr = lr\n",
    "        self.n_epochs = n_epochs\n",
    "        self.gpu = gpu\n",
    "\n",
    "    def __repr__(self):\n",
    "        ret = self.__class__.__name__\n",
    "        ret += \"(lr={}, n_epochs={}, gpu={})\".format(self.lr, self.n_epochs, self.gpu)\n",
    "        return ret\n",
    "        \n",
    "    def __call__(\n",
    "        self, model: nn.Module, training: Dataset, validation: Dataset, le: Encoder\n",
    "    ):\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "        if self.gpu:\n",
    "            model.cuda()\n",
    "            criterion.cuda()\n",
    "\n",
    "        optimizer = Adam(model.parameters(), lr=self.lr)\n",
    "\n",
    "        train_data = DataLoader(\n",
    "            training.as_torch(le=le), shuffle=True, batch_size=512\n",
    "        )\n",
    "        val_data = DataLoader(\n",
    "            validation.as_torch(le=le), shuffle=True, batch_size=512\n",
    "        )\n",
    "\n",
    "        # Save two lists for plotting a convergence graph at the end\n",
    "        ret_train_loss = list()\n",
    "        ret_val_loss = list()\n",
    "        \n",
    "        for epoch in range(self.n_epochs):\n",
    "            train_loss = self._train_one_epoch(\n",
    "                model=model, data=train_data, loss_fn=criterion, optimizer=optimizer\n",
    "            )\n",
    "            print(\"On Epoch {} the training loss was {}\".format(epoch, train_loss))\n",
    "            ret_train_loss.append(train_loss)\n",
    "\n",
    "            val_loss = self._validate_once(\n",
    "                model=model, data=val_data, loss_fn=criterion\n",
    "            )\n",
    "            print(\"---- validation loss was {}\".format(val_loss))\n",
    "            ret_val_loss.append(val_loss)\n",
    "\n",
    "        return ret_train_loss, ret_val_loss\n",
    "\n",
    "    def _train_one_epoch(\n",
    "        self, model: nn.Module, data: DataLoader, loss_fn: nn.CrossEntropyLoss, optimizer: Adam\n",
    "    ) -> float:\n",
    "        total_loss = 0.0\n",
    "        # Switch the model mode so it remembers gradients, induces dropout, etc.\n",
    "        model.train()\n",
    "\n",
    "        for i, batch in enumerate(data):\n",
    "            x, y = batch\n",
    "\n",
    "            # Push data to GPU if necessary\n",
    "            if self.gpu:\n",
    "                x = Variable(x.cuda())\n",
    "                y = Variable(y.cuda())\n",
    "            else:\n",
    "                x = Variable(x)\n",
    "                y = Variable(y)\n",
    "\n",
    "            # Forward pass of prediction\n",
    "            outputs = model(x)\n",
    "\n",
    "            # Zero out the parameter gradients, because they are cumulative,\n",
    "            # compute loss, compute gradients (backward), update weights\n",
    "            loss = loss_fn(outputs, y)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        mean_loss = total_loss / (i + 1.0)\n",
    "        return mean_loss\n",
    "\n",
    "    def _validate_once(\n",
    "        self, model: nn.Module, data: DataLoader, loss_fn: nn.CrossEntropyLoss\n",
    "    ) -> float:\n",
    "        total_loss = 0.0\n",
    "        # Switch the model back to test mode (so that batch norm/dropout doesn't\n",
    "        # take effect)\n",
    "        model.eval()\n",
    "        for i, batch in enumerate(data):\n",
    "            x, y = batch\n",
    "\n",
    "            if self.gpu:\n",
    "                x = x.cuda()\n",
    "                y = y.cuda()\n",
    "\n",
    "            outputs = model(x)\n",
    "            loss = loss_fn(outputs, y)\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        mean_loss = total_loss / (i + 1.0)\n",
    "        return mean_loss"
   ]
  },
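  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`_validate_once` above runs its forward pass with autograd still recording. A common refinement, sketched here under the assumption that only the loss value is needed during validation, is to wrap the loop in `torch.no_grad()`. The `validate_without_grad` function, the toy linear model, and the random tensors are hypothetical stand-ins for illustration and are not part of `rfml`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "\n",
    "def validate_without_grad(model, data, loss_fn):\n",
    "    # Hypothetical helper: mean loss over all batches with autograd disabled\n",
    "    model.eval()\n",
    "    total_loss = 0.0\n",
    "    n_batches = 0\n",
    "    with torch.no_grad():\n",
    "        for x, y in data:\n",
    "            total_loss += loss_fn(model(x), y).item()\n",
    "            n_batches += 1\n",
    "    # Guard against an empty loader\n",
    "    return total_loss / max(n_batches, 1)\n",
    "\n",
    "\n",
    "# Toy stand-ins: 64 random examples with 8 features and 3 classes\n",
    "torch.manual_seed(0)\n",
    "toy_x = torch.randn(64, 8)\n",
    "toy_y = torch.randint(0, 3, (64,))\n",
    "toy_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=16)\n",
    "toy_model = nn.Linear(8, 3)\n",
    "\n",
    "toy_val_loss = validate_without_grad(toy_model, toy_loader, nn.CrossEntropyLoss())\n",
    "print(\"toy validation loss: {:.4f}\".format(toy_val_loss))"
   ]
  },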
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10154,
     "status": "ok",
     "timestamp": 1563489016320,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "WajMNLpTzsnc",
    "outputId": "94989e07-f1b6-4e6c-b5a8-50e6338b8972"
   },
   "outputs": [],
   "source": [
    "trainer = MyTrainingStrategy(gpu=gpu)\n",
    "print(trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "69T8pBd0zsne"
   },
   "source": [
    "### Putting it All Together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 234194,
     "status": "ok",
     "timestamp": 1563489240373,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "at04fc4Hzsne",
    "outputId": "182088e4-af59-4738-f2c0-943fc9738b20"
   },
   "outputs": [],
   "source": [
    "train_loss, val_loss = trainer(model=model,\n",
    "                               training=train,\n",
    "                               validation=val,\n",
    "                               le=le)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 350
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 234190,
     "status": "ok",
     "timestamp": 1563489240380,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "MRO2b-Ndzsnf",
    "outputId": "c0ec81d3-778e-4088-9357-c03d47aa2894"
   },
   "outputs": [],
   "source": [
    "title = \"Training Results of {model_name} on {dataset_name}\".format(model_name=\"MyCNN\", dataset_name=\"RML2016.10A\")\n",
    "fig = plot_convergence(train_loss=train_loss, val_loss=val_loss, title=title)\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/training_loss.pdf\"\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XbwKjYygzsnh"
   },
   "source": [
    "### Testing the Trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 249181,
     "status": "ok",
     "timestamp": 1563489255384,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "uAZ3mZEzzsnh",
    "outputId": "3492e760-36a1-4410-bcec-ed281baec99e"
   },
   "outputs": [],
   "source": [
    "acc = compute_accuracy(model=model, data=test, le=le)\n",
    "print(\"Overall Testing Accuracy: {:.4f}\".format(acc))"
   ]
  },
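  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, overall accuracy is the fraction of examples whose highest-scoring class matches the true label. The cell below is a minimal sketch of that calculation on hand-written logits. It is not the implementation of `rfml.nn.eval.compute_accuracy`, and the `toy_*` names and values are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical network outputs for 4 examples and 3 classes\n",
    "toy_logits = torch.tensor([[2.0, 0.1, -1.0],\n",
    "                           [0.3, 1.5, 0.2],\n",
    "                           [0.0, 0.2, 0.1],\n",
    "                           [1.0, 0.9, 3.0]])\n",
    "toy_labels = torch.tensor([0, 1, 2, 2])\n",
    "\n",
    "# The predicted class is the index of the largest logit in each row\n",
    "toy_predictions = toy_logits.argmax(dim=1)\n",
    "toy_accuracy = (toy_predictions == toy_labels).float().mean().item()\n",
    "\n",
    "print(toy_predictions.tolist())  # [0, 1, 1, 2]\n",
    "print(toy_accuracy)              # 0.75"
   ]
  },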
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 350
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 264339,
     "status": "ok",
     "timestamp": 1563489270553,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "v40LSsC4zsni",
    "outputId": "548992da-3b6e-47cd-e5e1-4dab8f149a26"
   },
   "outputs": [],
   "source": [
    "acc_vs_snr, snr = compute_accuracy_on_cross_sections(model=model,\n",
    "                                                     data=test,\n",
    "                                                     le=le,\n",
    "                                                     column=\"SNR\")\n",
    "\n",
    "title = \"Accuracy vs SNR of {model_name} on {dataset_name}\".format(model_name=\"MyCNN\", dataset_name=\"RML2016.10A\")\n",
    "fig = plot_acc_vs_snr(acc_vs_snr=acc_vs_snr, snr=snr, title=title)\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/acc_vs_snr.pdf\"\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 380
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 280776,
     "status": "ok",
     "timestamp": 1563489287004,
     "user": {
      "displayName": "Bryse Flowers",
      "photoUrl": "https://lh3.googleusercontent.com/-dfqbkjv2TMY/AAAAAAAAAAI/AAAAAAAAAAg/nmDc_6TAqq4/s64/photo.jpg",
      "userId": "13954389111173404421"
     },
     "user_tz": 240
    },
    "id": "4V8j3KfYzsnj",
    "outputId": "aec4d7b6-0cd1-4de0-fe5b-d0f055fb7ac7"
   },
   "outputs": [],
   "source": [
    "cmn = compute_confusion(model=model, data=test, le=le)\n",
    "\n",
    "title = \"Confusion Matrix of {model_name} on {dataset_name}\".format(model_name=\"MyCNN\", dataset_name=\"RML2016.10A\")\n",
    "fig = plot_confusion(cm=cmn, labels=le.labels, title=title)\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/confusion_matrix.pdf\"\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "module2_solutions.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
